%load_ext autoreload
%autoreload 2
import banner
topics = ['Introduction to Transformers',
'Very Simple Implementation of Transformers']
banner.reset(topics)
Topics in this Notebook 1. Introduction to Transformers 2. Very Simple Implementation of Transformers
banner.next_topic()
Adapted from *Transformers from Scratch*.
This example uses JAX. For a recent, very brief summary of JAX, see *JAX: Fast as PyTorch, Simple as NumPy*.
banner.next_topic()
import jax.numpy as jnp
import jax
import numpy
import pandas
import matplotlib.pyplot as plt
import time
import re # regular expressions
# jax.config.update('jax_platform_name', 'cpu')
Download the `Fake.csv` and `True.csv` files from the *Fake and real news dataset* at Kaggle.
df_fake = pandas.read_csv('Fake.csv', usecols=['title'])
df_real = pandas.read_csv('True.csv', usecols=['title'])
df_fake.shape, df_real.shape
((23481, 1), (21417, 1))
pandas.set_option('max_colwidth', None)
df_fake.head(5)
title | |
---|---|
0 | Donald Trump Sends Out Embarrassing New Year’s Eve Message; This is Disturbing |
1 | Drunk Bragging Trump Staffer Started Russian Collusion Investigation |
2 | Sheriff David Clarke Becomes An Internet Joke For Threatening To Poke People ‘In The Eye’ |
3 | Trump Is So Obsessed He Even Has Obama’s Name Coded Into His Website (IMAGES) |
4 | Pope Francis Just Called Out Donald Trump During His Christmas Speech |
df_real.head(5)
title | |
---|---|
0 | As U.S. budget fight looms, Republicans flip their fiscal script |
1 | U.S. military to accept transgender recruits on Monday: Pentagon |
2 | Senior U.S. Republican senator: 'Let Mr. Mueller do his job' |
3 | FBI Russia probe helped by Australian diplomat tip-off: NYT |
4 | Trump wants Postal Service to charge 'much more' for Amazon shipments |
# Use only the first `keep` headlines from each file; fake headlines are
# labeled 0 and real headlines 1.
keep = 1000
headlines_fake = df_fake.values[:keep]
labels_fake = numpy.zeros((headlines_fake.shape[0]))
headlines_real = df_real.values[:keep]
labels_real = numpy.ones((headlines_real.shape[0]))
headlines_fake.shape, labels_fake.shape, headlines_real.shape, labels_real.shape
((1000, 1), (1000,), (1000, 1), (1000,))
# Keep the raw (uncleaned) headlines, fake first then real, one per row, so
# row i lines up with labels[i]. `headlines_fake + headlines_real` on these
# (n, 1) object arrays would instead concatenate the i-th fake and the i-th
# real headline strings elementwise.
headlines_orig = numpy.vstack((headlines_fake, headlines_real))
from string import punctuation
punctuation + ' '
print(punctuation)
def clean_up_words(titles):
    """Tokenize headlines into lowercase words.

    titles: iterable of rows whose first element is the headline string
        (e.g. the rows of an (n, 1) array from ``DataFrame.values``).

    Returns one list of words per title: the headline split on runs of
    non-word characters, lowercased, keeping only words longer than one
    character. The length filter drops split artifacts such as the leading
    '' and the 's' left over from "Year’s".
    """
    # Raw string: '\W' in a normal literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+). Relies on the module-level `import re`.
    word_lists = [[word.lower() for word in re.split(r'\W+', title[0])]
                  for title in titles]
    return [[word for word in words if len(word) > 1] for words in word_lists]
!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~
# Show the raw split of the first fake headline (raw string avoids the
# invalid '\W' escape-sequence warning).
re.split(r'\W+', headlines_fake[0][0])
['', 'Donald', 'Trump', 'Sends', 'Out', 'Embarrassing', 'New', 'Year', 's', 'Eve', 'Message', 'This', 'is', 'Disturbing']
clean_up_words(headlines_fake[0:2])
[['donald', 'trump', 'sends', 'out', 'embarrassing', 'new', 'year', 'eve', 'message', 'this', 'is', 'disturbing'], ['drunk', 'bragging', 'trump', 'staffer', 'started', 'russian', 'collusion', 'investigation']]
headlines_fake = clean_up_words(headlines_fake)
headlines_real = clean_up_words(headlines_real)
# Fake headlines first, then real, matching the order of the stacked labels.
headlines = headlines_fake + headlines_real
labels = numpy.hstack((labels_fake, labels_real))
len([len(h) for h in headlines]), len(labels)
(2000, 2000)
mx = max([len(h) for h in headlines])
mx
22
# Right-pad every headline with ' ' tokens up to the longest length mx, so all
# headlines have the same number of tokens.
headlines = [headline + [' '] * (mx - len(headline)) for headline in headlines]
len(headlines[0]), headlines[0]
(22, ['donald', 'trump', 'sends', 'out', 'embarrassing', 'new', 'year', 'eve', 'message', 'this', 'is', 'disturbing', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' '])
# Vocabulary = sorted unique words over all padded headlines (numpy.unique
# sorts), so the ' ' padding token sorts first and gets index 0.
words = [word for headline in headlines for word in headline]
vocabulary = numpy.unique(words)
len(vocabulary)
4581
# Print "index word; " for the first 100 vocabulary entries, then a line
# break, then entries 4300 through 4580.
for idx in range(100):
    print(idx, vocabulary[idx], end='; ')
print()
for idx in range(4300, 4581):
    print(idx, vocabulary[idx], end='; ')
0 ; 1 000; 2 10; 3 100k; 4 11; 5 12; 6 13; 7 14; 8 15; 9 168; 10 17; 11 18; 12 19; 13 1993; 14 1st; 15 20; 16 2016; 17 2017; 18 2018; 19 2019; 20 2020; 21 20k; 22 21; 23 22; 24 25; 25 250; 26 253; 27 26; 28 27; 29 28; 30 29; 31 2nd; 32 30; 33 30k; 34 32; 35 35; 36 375; 37 401; 38 44; 39 45; 40 46; 41 49ers; 42 50; 43 500; 44 65; 45 70; 46 700; 47 74; 48 80; 49 800; 50 81; 51 95; 52 abandon; 53 abandoned; 54 abandoning; 55 abc; 56 abducted; 57 abe; 58 ability; 59 able; 60 aborted; 61 abortion; 62 about; 63 above; 64 abroad; 65 abrupt; 66 abruptly; 67 absolute; 68 absolutely; 69 abuse; 70 abuses; 71 aca; 72 accent; 73 accept; 74 access; 75 accident; 76 accidentally; 77 accomplice; 78 accomplished; 79 accomplishes; 80 according; 81 account; 82 accountability; 83 accounts; 84 accusations; 85 accuse; 86 accused; 87 accuser; 88 accusers; 89 accuses; 90 accusing; 91 acknowledge; 92 acknowledges; 93 acre; 94 across; 95 act; 96 acting; 97 action; 98 actions; 99 activists; 4300 until; 4301 unveil; 4302 unveiled; 4303 unwanted; 4304 unworthy; 4305 up; 4306 upbeat; 4307 upcoming; 4308 upholds; 4309 upon; 4310 uprising; 4311 uproar; 4312 upset; 4313 uranium; 4314 urge; 4315 urged; 4316 urges; 4317 urging; 4318 us; 4319 usa; 4320 use; 4321 used; 4322 user; 4323 uses; 4324 usher; 4325 using; 4326 uss; 4327 utah; 4328 utility; 4329 vacation; 4330 vacationing; 4331 vacay; 4332 valley; 4333 valor; 4334 van; 4335 vanity; 4336 vast; 4337 vaunts; 4338 ve; 4339 vegas; 4340 vehicle; 4341 venezuelans; 4342 very; 4343 veteran; 4344 veterans; 4345 via; 4346 vice; 4347 viciously; 4348 victim; 4349 victims; 4350 victory; 4351 video; 4352 videos; 4353 vietnam; 4354 view; 4355 views; 4356 vile; 4357 violate; 4358 violated; 4359 violates; 4360 violating; 4361 violence; 4362 violent; 4363 violently; 4364 viral; 4365 virginia; 4366 visa; 4367 visas; 4368 visit; 4369 visitor; 4370 visits; 4371 vocal; 4372 voice; 4373 voices; 4374 vomit; 4375 vote; 4376 voted; 4377 voter; 4378 voters; 4379 votes; 
4380 voting; 4381 vow; 4382 vows; 4383 waive; 4384 waivers; 4385 waiving; 4386 wake; 4387 wakes; 4388 walk; 4389 walks; 4390 wall; 4391 wallace; 4392 walls; 4393 walmart; 4394 walter; 4395 wander; 4396 want; 4397 wanted; 4398 wanting; 4399 wants; 4400 wapo; 4401 war; 4402 warmer; 4403 warming; 4404 warned; 4405 warner; 4406 warning; 4407 warnings; 4408 warns; 4409 warrant; 4410 warren; 4411 warriors; 4412 warsaw; 4413 was; 4414 washington; 4415 wasn; 4416 wastes; 4417 watch; 4418 watchdog; 4419 watches; 4420 watching; 4421 water; 4422 watergate; 4423 waters; 4424 wave; 4425 waver; 4426 way; 4427 ways; 4428 we; 4429 weakens; 4430 weakness; 4431 wealth; 4432 wealthy; 4433 weapon; 4434 weaponry; 4435 weapons; 4436 website; 4437 websites; 4438 wedding; 4439 wednesday; 4440 week; 4441 weeks; 4442 weigh; 4443 weighed; 4444 weighing; 4445 weighs; 4446 weight; 4447 weird; 4448 welcome; 4449 welfare; 4450 well; 4451 wellbeing; 4452 wells; 4453 went; 4454 were; 4455 weren; 4456 west; 4457 wets; 4458 wh; 4459 what; 4460 when; 4461 where; 4462 whether; 4463 which; 4464 while; 4465 whine; 4466 whinefest; 4467 whines; 4468 whining; 4469 whiny; 4470 whip; 4471 white; 4472 who; 4473 whoever; 4474 whole; 4475 whoop; 4476 whose; 4477 why; 4478 wide; 4479 widely; 4480 widens; 4481 widow; 4482 wife; 4483 wilbur; 4484 wildfires; 4485 wildly; 4486 will; 4487 willing; 4488 win; 4489 wind; 4490 winds; 4491 wing; 4492 winner; 4493 winners; 4494 winning; 4495 wins; 4496 wiping; 4497 wire; 4498 wisconsin; 4499 witch; 4500 with; 4501 withdraw; 4502 withdrawal; 4503 withdrawn; 4504 withdraws; 4505 withholding; 4506 within; 4507 without; 4508 witness; 4509 witnessed; 4510 witnesses; 4511 wizard; 4512 woman; 4513 women; 4514 won; 4515 woodshed; 4516 word; 4517 words; 4518 work; 4519 worked; 4520 worker; 4521 workers; 4522 working; 4523 world; 4524 worldwide; 4525 worried; 4526 worries; 4527 worry; 4528 worse; 4529 worship; 4530 worst; 4531 worth; 4532 worthy; 4533 would; 4534 wouldn; 4535 
wounded; 4536 wrapped; 4537 wraps; 4538 wray; 4539 wrecked; 4540 wrecks; 4541 wrestling; 4542 write; 4543 writer; 4544 writes; 4545 wrong; 4546 wrongdoing; 4547 wrote; 4548 wsj; 4549 wtf; 4550 wwiii; 4551 xenophobic; 4552 xi; 4553 xinhua; 4554 xmas; 4555 yankuang; 4556 yates; 4557 yawns; 4558 yeah; 4559 year; 4560 years; 4561 yellen; 4562 yelling; 4563 yemen; 4564 yesterday; 4565 yet; 4566 yiannopoulos; 4567 york; 4568 yorker; 4569 you; 4570 young; 4571 your; 4572 yourself; 4573 youth; 4574 yulín; 4575 zealand; 4576 zeldin; 4577 zero; 4578 zhong; 4579 zilch; 4580 zuckerberg;
numpy.where(numpy.array(headlines[0]).reshape(-1, 1) == vocabulary)
(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]), array([1245, 4220, 3610, 2837, 1355, 2696, 4559, 1418, 2531, 4097, 2139, 1214, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
numpy.where(numpy.array(headlines[0]).reshape(-1, 1) == vocabulary)[1]
array([1245, 4220, 3610, 2837, 1355, 2696, 4559, 1418, 2531, 4097, 2139, 1214, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
# Map each padded headline word to its index in the sorted vocabulary. A dict
# lookup is O(1) per word instead of comparing each word against all
# len(vocabulary) entries with a broadcast ==. Every word is in the
# vocabulary because the vocabulary was built from these same headlines.
word_to_id = {word: i for i, word in enumerate(vocabulary)}
tokens = [numpy.array([word_to_id[word] for word in headline], dtype=numpy.intp)
          for headline in headlines]
tokens[:20]
[array([1245, 4220, 3610, 2837, 1355, 2696, 4559, 1418, 2531, 4097, 2139, 1214, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([1297, 541, 4220, 3848, 3864, 3496, 819, 2122, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([3664, 1039, 766, 406, 208, 2113, 2182, 1619, 4105, 4136, 2997, 2930, 2034, 4081, 1478, 0, 0, 0, 0, 0, 0, 0]), array([4220, 2139, 3766, 2756, 1872, 1419, 1858, 2752, 2657, 805, 2117, 1920, 4436, 2008, 0, 0, 0, 0, 0, 0, 0, 0]), array([3011, 1646, 2201, 618, 2837, 1245, 4220, 1307, 1920, 745, 3804, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([3204, 168, 938, 576, 471, 535, 4464, 1872, 2139, 2034, 1828, 1772, 2008, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([1658, 2767, 4081, 1741, 960, 4220, 2274, 2837, 309, 1522, 1127, 1172, 214, 2157, 829, 0, 0, 0, 0, 0, 0, 0]), array([4220, 3506, 3776, 2084, 3204, 3927, 2086, 4081, 2846, 2774, 214, 4510, 356, 2149, 4305, 0, 0, 0, 0, 0, 0, 0]), array([1634, 750, 1172, 3736, 4220, 2847, 4257, 587, 2807, 3953, 1872, 96, 2364, 1153, 4243, 0, 0, 0, 0, 0, 0, 0]), array([4417, 546, 2696, 3091, 4220, 104, 1527, 3766, 2631, 2238, 2149, 4486, 2441, 4569, 3695, 0, 0, 0, 0, 0, 0, 0]), array([2875, 2176, 1641, 3403, 1550, 2837, 3203, 2139, 367, 1619, 600, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([4417, 2908, 3498, 2201, 4141, 4318, 1872, 1233, 647, 62, 3924, 1504, 2386, 2034, 505, 3873, 0, 0, 0, 0, 0, 0]), array([ 367, 2699, 1619, 4220, 2579, 2494, 3536, 2712, 4136, 3343, 2753, 2034, 18, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([4417, 2372, 1766, 4189, 2508, 1619, 3016, 4220, 287, 2251, 1630, 1920, 2851, 4517, 0, 0, 0, 0, 0, 0, 0, 0]), array([1888, 4136, 1199, 1365, 2250, 1753, 3539, 4318, 3690, 4083, 1619, 4029, 453, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([4145, 1049, 4220, 883, 3339, 3537, 2795, 2410, 4446, 138, 4569, 187, 1157, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([4081, 2113, 577, 2583, 1199, 2696, 4220, 3456, 309, 1817, 2766, 3063, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([2632, 3822, 2201, 760, 4305, 1245, 4220, 745, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0]), array([3761, 1910, 2583, 86, 731, 2591, 3477, 2605, 1619, 2410, 167, 3605, 3199, 4351, 0, 0, 0, 0, 0, 0, 0, 0]), array([3361, 3606, 1709, 1272, 1619, 1738, 138, 3454, 2632, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])]
tokens[0]
array([1245, 4220, 3610, 2837, 1355, 2696, 4559, 1418, 2531, 4097, 2139, 1214, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
X_tokens = jnp.array(tokens)
X_tokens
Array([[1245, 4220, 3610, ..., 0, 0, 0], [1297, 541, 4220, ..., 0, 0, 0], [3664, 1039, 766, ..., 0, 0, 0], ..., [4126, 4048, 2652, ..., 0, 0, 0], [1437, 4464, 126, ..., 0, 0, 0], [2487, 4370, 3615, ..., 0, 0, 0]], dtype=int32)
# Create one transformer block with the following steps.
#
# 1. Make embedding layer
# 2. Make position encoding
# 3. Make weight matrices for mapping word embedding to key, query, and value.
# 4. Make weight matrix for combining output of all heads
# 5. Define forward pass for self-attention
# 6. Make weights for dense net to apply to output of self-attention.
# 7. Define forward pass through dense net.
# 8. Make weights to linearly convert output of dense net to log probs for each class.
#
# Finally, start classification.
#
# 9. Convert the list of tokens for each headline into their respective embeddings.
# 10. Pass each embedding through self-attention, then the dense net.
# 11. Calc the mean over all outputs for a headline, then linearly reduce to log probs for each class.
# 12. Convert log probs to probs
######################################################################
# 1. Make embedding layer
n_vocabulary_words = len(vocabulary)
embed_dim = 40
# One embed_dim-long row per vocabulary word (row index == token id), drawn
# from a standard normal with numpy's global RNG.
embedder_W = jnp.array(numpy.random.normal(size=(n_vocabulary_words, embed_dim)))
type(embedder_W) # , embedder_W.device()
jaxlib.xla_extension.ArrayImpl
######################################################################
# 2. Make positional encoding (from http://jalammar.github.io/illustrated-transformer/)
def get_angles(pos, i, d_model, base=10):
    """Return the sinusoid angles pos / base**(2*(i//2) / d_model).

    pos and i broadcast against each other (make_position_encoding passes a
    column of positions and a row of dimension indices). Consecutive
    dimension pairs (2k, 2k+1) share the same rate.

    base: wavelength base. The default 10 keeps this notebook's existing
    behavior; the referenced tutorial (and the original Transformer paper)
    use 10000.
    """
    angle_rates = 1 / numpy.power(base, (2 * (i // 2)) / numpy.float32(d_model))
    return pos * angle_rates
def make_position_encoding(position, d_model):
    """Build a sinusoidal positional-encoding table of shape (position, d_model).

    Angles come from get_angles for every (position, dimension) pair. Even
    columns hold sin(angle) and odd columns hold cos(angle). The table is
    returned as a JAX array.
    """
    positions = numpy.arange(position)[:, numpy.newaxis]
    dimensions = numpy.arange(d_model)[numpy.newaxis, :]
    angles = get_angles(positions, dimensions, d_model)
    encoding = angles.copy()
    encoding[:, 0::2] = numpy.sin(angles[:, 0::2])   # dims 2i
    encoding[:, 1::2] = numpy.cos(angles[:, 1::2])   # dims 2i+1
    return jnp.array(encoding)
# One positional-encoding row per token slot; X_tokens rows all have the same
# (padded) length.
n_tokens_per_review = X_tokens.shape[1]
position_encoding = make_position_encoding(n_tokens_per_review, embed_dim)
plt.imshow(position_encoding[:, :]);
plt.axis('auto')
(-0.5, 39.5, 21.5, -0.5)
######################################################################
# 3. Make weight matrices for mapping word embedding to key, query, and value.
# 4. Make weight matrix for combining output of all heads
def make_weights(n_in, n_out):
    """Return an (n_in, n_out) JAX weight matrix.

    Entries are drawn uniformly (numpy's global RNG) from
    [-1/sqrt(n_in), 1/sqrt(n_in)).
    """
    bound = 1 / jnp.sqrt(n_in)
    samples = numpy.random.uniform(-bound, bound, size=(n_in, n_out))
    return jnp.array(samples)
# Every head gets its own embed_dim x embed_dim key/query/value matrices (heads
# are as wide as the embedding). W_combine has n_heads * embed_dim rows so it
# can take all head outputs side by side back down to embed_dim.
n_heads = 8
n_in_per_head = embed_dim
W_keys = [make_weights(n_in_per_head, n_in_per_head) for h in range(n_heads)]
W_queries = [make_weights(n_in_per_head, n_in_per_head) for h in range(n_heads)]
W_values = [make_weights(n_in_per_head, n_in_per_head) for h in range(n_heads)]
W_combine = make_weights(n_heads * n_in_per_head, embed_dim)
print(f'W_key shapes {[w.shape for w in W_keys]}')
print(f'W_query shapes {[w.shape for w in W_queries]}')
print(f'W_value shapes {[w.shape for w in W_values]}')
print(f'W_combine shape {W_combine.shape}')
W_key shapes [(40, 40), (40, 40), (40, 40), (40, 40), (40, 40), (40, 40), (40, 40), (40, 40)] W_query shapes [(40, 40), (40, 40), (40, 40), (40, 40), (40, 40), (40, 40), (40, 40), (40, 40)] W_value shapes [(40, 40), (40, 40), (40, 40), (40, 40), (40, 40), (40, 40), (40, 40), (40, 40)] W_combine shape (320, 40)
######################################################################
# 5. Define forward pass for self-attention
def softmax(Y, dim):
    """Softmax of Y along axis dim.

    The per-slice maximum is subtracted first, which leaves the result
    unchanged but keeps exp from overflowing.
    """
    shifted = Y - jnp.max(Y, axis=dim, keepdims=True)
    exps = jnp.exp(shifted)
    return exps / jnp.sum(exps, axis=dim, keepdims=True)
def forward_attention(params, X):
    """Multi-head self-attention over a batch of token embeddings.

    params: the 8-element parameter list; entries 1-4 are W_keys, W_queries,
        W_values (one matrix per head) and W_combine.
    X: array of shape (n_samples, n_tokens, embed_dim).

    Returns (attention, QKs):
      attention -- (n_samples, n_tokens, W_combine.shape[1]) mix of all heads.
      QKs -- list with one (n_samples, n_tokens, n_tokens) attention-weight
             array per head (rows sum to 1 over the last axis).
    """
    _, W_keys, W_queries, W_values, W_combine, _, _, _ = params
    n_samples, n_tokens, embed_dim = X.shape

    # Layer Normalization pre version, as in https://arxiv.org/pdf/2002.04745.pdf
    X_norm = (X - X.mean(-1, keepdims=True)) / X.std(-1, keepdims=True)

    scale = jnp.sqrt(embed_dim)
    QKs = [softmax((X_norm @ W_query) @ jnp.swapaxes(X_norm @ W_key, 1, 2) / scale, dim=2)
           for W_query, W_key in zip(W_queries, W_keys)]
    head_outputs = [weights @ (X_norm @ W_value)
                    for weights, W_value in zip(QKs, W_values)]

    # Stack heads on a trailing axis, flatten (embed_dim, head) per token, and
    # project with W_combine.
    stacked = jnp.stack(head_outputs, axis=-1).reshape(n_samples, n_tokens, -1)
    return stacked @ W_combine, QKs
######################################################################
# 6. Make weights for dense net to apply to output of self-attention.
# Hidden layer is n_ff_units * embed_dim wide; ff_W2 projects back to embed_dim.
n_ff_units = 10
ff_W1 = make_weights(embed_dim, n_ff_units * embed_dim)
ff_W2 = make_weights(n_ff_units * embed_dim, embed_dim)
print(f'ff_W1 shape {ff_W1.shape}')
print(f'ff_W2 shape {ff_W2.shape}')
ff_W1 shape (40, 400) ff_W2 shape (400, 40)
######################################################################
# 7. Define forward pass through dense net.
def forward_transform_block(params, attention, X):
    """Residual add, layer norm, then a two-layer tanh dense net.

    params: parameter list whose last two entries are ff_W1 and ff_W2.
    attention, X: arrays of the same shape (..., embed_dim).
    Returns tanh(tanh(LN(attention + X) @ ff_W1) @ ff_W2).
    """
    *_, ff_W1, ff_W2 = params
    residual = attention + X
    mean = residual.mean(-1, keepdims=True)
    std = residual.std(-1, keepdims=True)
    normed = (residual - mean) / std
    hidden = jnp.tanh(normed @ ff_W1)
    return jnp.tanh(hidden @ ff_W2)
######################################################################
# 8. Make weights to linearly convert output of dense net to log probs for each class.
n_classes = 2
# Maps each token's embed_dim-wide output to one score per class.
W_toprobs = make_weights(embed_dim, n_classes)
print(f'W_toprobs shape {W_toprobs.shape}')
W_toprobs shape (40, 2)
######################################################################
# Finally, start classification.
# 9. Convert list of tokens for each headline into their respective embeddings.
# This repeats what forward() does; here X is just printed to check its shape.
embedding = jnp.take(embedder_W, X_tokens, axis=0) # n_reviews x max_n_tokens x embed_dim
X = embedding + position_encoding
print(f'X.shape {X.shape}')
X.shape (2000, 22, 40)
######################################################################
# 10. Pass each embedding through self-attention then dense net.
def forward(params, X_tokens):
    """Run the one-block transformer classifier on a batch of token rows.

    params: [embedder_W, W_keys, W_queries, W_values, W_combine, W_toprobs,
        ff_W1, ff_W2] -- the order forward_attention unpacks and the
        negative indices used here and in forward_transform_block rely on.
    X_tokens: integer array (n_samples, n_tokens); each row holds the token
        ids of one headline.

    Returns an (n_samples, n_classes) array of unnormalized class scores,
    averaged over token positions.
    """
    embedder_W = params[0]
    X_embedding = jnp.take(embedder_W, X_tokens, axis=0)  # n_samples x n_tokens x embed_dim
    n_tokens = X_tokens.shape[1]
    # Take only as many positional-encoding rows as there are tokens, so
    # batches with fewer than position_encoding.shape[0] tokens still broadcast.
    X = X_embedding + position_encoding[:n_tokens]
    attention, _ = forward_attention(params, X)
    Y = forward_transform_block(params, attention, X)
    W_toprobs = params[-3]
    return (Y @ W_toprobs).mean(axis=1)  # mean over all token outputs
# Order matters: forward_attention unpacks all eight entries positionally, and
# forward / forward_transform_block read W_toprobs, ff_W1, ff_W2 as
# params[-3], params[-2], params[-1].
params = [embedder_W, W_keys, W_queries, W_values, W_combine, W_toprobs, ff_W1, ff_W2]
forward(params, X_tokens)
Array([[ 0.0766996 , -0.08933091], [ 0.16505994, -0.09998061], [ 0.10354675, -0.06341257], ..., [ 0.12111524, -0.10224821], [ 0.10253894, -0.09157991], [ 0.1086821 , -0.10692846]], dtype=float32)
labels.reshape(-1, 1) == numpy.unique(labels)
array([[ True, False], [ True, False], [ True, False], ..., [False, True], [False, True], [False, True]])
def make_indicator_vars(labels):
    """One-hot encode a 1-D label array.

    Columns follow the sorted unique values of labels. Returns an integer
    JAX array of shape (len(labels), n_unique_labels).
    """
    classes = numpy.unique(labels)
    one_hot = labels.reshape(-1, 1) == classes
    return jnp.array(one_hot.astype(int))
make_indicator_vars(labels)
Array([[1, 0], [1, 0], [1, 0], ..., [0, 1], [0, 1], [0, 1]], dtype=int32)
# Now we can implement the loss function, compute its gradient, and implement a training loop.
def loss(params, X_tokens, T):
    """Cross-entropy loss of the model on a batch.

    T: one-hot targets of shape (n_samples, n_classes).
    Returns -mean(T * log softmax(forward(params, X_tokens))), where the mean
    runs over every (sample, class) entry. That equals the usual mean
    cross-entropy divided by n_classes.
    """
    Y = forward(params, X_tokens)
    # log_softmax computes Y - logsumexp(Y) directly. log(softmax(Y)) can
    # underflow to log(0) = -inf, and then 0 * -inf = nan poisons the gradient.
    log_probs = jax.nn.log_softmax(Y, axis=-1)
    return -jnp.mean(T * log_probs)
# One-hot targets (column 0 = fake/label 0, column 1 = real/label 1) and the
# loss of the untrained model on all headlines.
T = make_indicator_vars(labels)
loss(params, X_tokens, T)
Array(0.35420027, dtype=float32)
def generate_stratified_partitions(X, T, n_folds, validation=True, shuffle=True):
    '''Generate stratified cross-validation partitions of X and T.

    Yields, for every test fold and (when validation is True) every other fold
    used as the validation fold:
        Xtrain, Ttrain, Xvalidate, Tvalidate, Xtest, Ttest
    or, when validation is False, for every test fold:
        Xtrain, Ttrain, Xtest, Ttest

    X: 2-D array (numpy or jax), one sample per row.
    T: 2-D array, one row per sample. A single column is read as class labels.
       More than one column is read as an indicator (one-hot) matrix, and each
       row's class is its argmax column.
    n_folds: number of folds. Each class's rows are cut into n_folds contiguous
       chunks, and the last chunk also takes the remainder.
    validation: also carve a validation fold out of the non-test folds.
    shuffle: shuffle row order with numpy.random before splitting.
    '''
    T_values = numpy.asarray(T)
    if T_values.shape[1] > 1:
        # Indicator targets. Comparing the whole matrix to a class value would
        # match every row once per class (each row holds both a 0 and a 1).
        # That puts every sample in every class and duplicates it in the folds.
        class_labels = T_values.argmax(axis=1)
    else:
        class_labels = T_values[:, 0]

    row_indices = numpy.arange(X.shape[0])
    if shuffle:
        numpy.random.shuffle(row_indices)

    # Per class: that class's (shuffled) row indices and the start/stop offsets
    # of each fold within them.
    folds = {}
    for c in numpy.unique(class_labels):
        class_indices = row_indices[class_labels[row_indices] == c]
        n_in_class = len(class_indices)
        n_each = n_in_class // n_folds
        # arange(n_folds) * n_each instead of arange(0, ..., n_each): a class
        # with fewer than n_folds rows has n_each == 0, which as a step raises.
        starts = numpy.arange(n_folds) * n_each
        stops = starts + n_each
        stops[-1] = n_in_class
        folds[c] = (class_indices, starts, stops)

    def rows_in_folds(ks):
        # Row indices of folds ks: fold-major, then class order.
        pieces = [class_indices[starts[k]:stops[k]]
                  for k in ks
                  for class_indices, starts, stops in folds.values()]
        if not pieces:
            return numpy.array([], dtype=int)
        return numpy.concatenate(pieces)

    for test_fold in range(n_folds):
        test_rows = rows_in_folds([test_fold])
        if validation:
            for validate_fold in range(n_folds):
                if validate_fold == test_fold:
                    continue
                train_folds = numpy.setdiff1d(range(n_folds), [test_fold, validate_fold])
                validate_rows = rows_in_folds([validate_fold])
                train_rows = rows_in_folds(train_folds)
                yield (X[train_rows, :], T[train_rows, :],
                       X[validate_rows, :], T[validate_rows, :],
                       X[test_rows, :], T[test_rows, :])
        else:
            # No validation set
            train_rows = rows_in_folds(numpy.setdiff1d(range(n_folds), [test_fold]))
            yield X[train_rows, :], T[train_rows, :], X[test_rows, :], T[test_rows, :]
# Use only the first (train, test) split of a 4-fold partition of the
# tokenized headlines and their one-hot targets.
Xtrain, Ttrain, Xtest, Ttest = next(generate_stratified_partitions(X_tokens, T, 4,
validation=False, shuffle=True))
print(f'{Xtrain.shape=} {Ttrain.shape=} {Xtest.shape=} {Ttest.shape=}')
def frac_pos(T):
    """Return the fraction of rows of indicator matrix T whose column 1 equals 1."""
    is_positive = T[:, 1] == 1
    return is_positive.mean().item()
print(f'{frac_pos(Ttrain)=:.2f} {frac_pos(Ttest)=:.2f}')
Xtrain.shape=(3000, 22) Ttrain.shape=(3000, 2) Xtest.shape=(1000, 22) Ttest.shape=(1000, 2) frac_pos(Ttrain)=0.50 frac_pos(Ttest)=0.50
# Compile loss + gradient once. forward/loss are pure jnp code, so tracing is
# safe, and it removes the per-operation dispatch that made training take
# minutes. Globals they read (e.g. position_encoding) are captured at trace time.
loss_grad = jax.jit(jax.value_and_grad(loss))
def train(n_steps, batch_size, learning_rate):
    '''Train the global params with plain mini-batch gradient descent.

    Uses the globals Xtrain, Ttrain and loss_grad, and rebinds the global
    params after every batch. A negative batch_size means one batch holding
    the whole training set. Batches are taken in order (no reshuffling).

    Returns (losses, likelihoods): losses is a jnp array with one loss per
    batch update; likelihoods is a list with, per step, the mean over that
    step's batches of exp(-loss).
    '''
    global params
    if batch_size < 0:
        batch_size = Xtrain.shape[0]
    print('Training started')
    losses = []
    likelihoods = []
    n_samples = Xtrain.shape[0]
    start_time = time.time()
    for step in range(n_steps):
        likelihoods_batch = []
        for first in range(0, n_samples, batch_size):
            Xtrain_batch = Xtrain[first:first + batch_size]
            Ttrain_batch = Ttrain[first:first + batch_size]
            loss_value, grads = loss_grad(params, Xtrain_batch, Ttrain_batch)
            losses.append(loss_value)
            likelihoods_batch.append(jnp.exp(-loss_value))
            # SGD step on every leaf; params and grads share the same nested
            # list structure (arrays plus the per-head weight lists).
            params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g,
                                            params, grads)
        likelihoods.append(jnp.mean(jnp.array(likelihoods_batch)))
        if (step + 1) % max(1, (n_steps // 20)) == 0:
            print(f'Step {step+1} Likelihood {likelihoods[-1]:.4f}')
    losses = jnp.array(losses)
    elapsed = time.time() - start_time
    print(f'Training took {elapsed/60:.1f} minutes.')
    return losses, likelihoods
# batch_size = -1: train() treats a negative batch size as one batch holding
# the whole training set.
batch_size = -1
learning_rate = 0.1
n_steps = 500
losses, likelihoods = train(n_steps, batch_size, learning_rate)
plt.plot(likelihoods);
Training started Step 25 Likelihood 0.8401 Step 50 Likelihood 0.8416 Step 75 Likelihood 0.8453 Step 100 Likelihood 0.8468 Step 125 Likelihood 0.8503 Step 150 Likelihood 0.8519 Step 175 Likelihood 0.8552 Step 200 Likelihood 0.8566 Step 225 Likelihood 0.8598 Step 250 Likelihood 0.8612 Step 275 Likelihood 0.8644 Step 300 Likelihood 0.8657 Step 325 Likelihood 0.8690 Step 350 Likelihood 0.8702 Step 375 Likelihood 0.8736 Step 400 Likelihood 0.8747 Step 425 Likelihood 0.8783 Step 450 Likelihood 0.8793 Step 475 Likelihood 0.8831 Step 500 Likelihood 0.8842 Training took 5.7 minutes.
# To use our transformer, do a forward pass in batches, convert outputs to probabilities, then calculate a confusion matrix.
# Note: X_tokens holds every headline, so rows used for training are included.
n_samples = X_tokens.shape[0]
if batch_size > 0:
    # Score batch_size headlines at a time and stack the per-batch outputs.
    Y = jnp.vstack([forward(params, X_tokens[start:start + batch_size])
                    for start in range(0, n_samples, batch_size)])
else:
    # Non-positive batch_size: score every headline in one call.
    Y = forward(params, X_tokens)
probs = softmax(Y, dim=1)
probs
Array([[0.95472497, 0.04527497], [0.14120224, 0.8587978 ], [0.99820185, 0.00179817], ..., [0.09610284, 0.9038972 ], [0.82064694, 0.17935303], [0.08950872, 0.9104912 ]], dtype=float32)
# Predicted class = column index of the larger score for each headline.
pred_classes = jnp.argmax(Y, axis=1)
actual_classes = labels # jnp.argmax(labels, axis=1)
# Row-normalized confusion matrix over every headline in X_tokens (training
# rows included): row r holds the fractions of actual-class-r headlines
# predicted as class 0 and as class 1.
row0 = [jnp.mean(pred_classes[actual_classes == 0] == 0),
jnp.mean(pred_classes[actual_classes == 0] == 1)]
row1 = [jnp.mean(pred_classes[actual_classes == 1] == 0),
jnp.mean(pred_classes[actual_classes == 1] == 1)]
cm = jnp.array([row0, row1])
# Shown as percentages.
pandas.DataFrame(100*cm,
columns=('Pred Neg', 'Pred Pos'),
index=('Actual Neg', 'Actual Pos'))
Pred Neg | Pred Pos | |
---|---|---|
Actual Neg | 92.400002 | 7.6 |
Actual Pos | 14.000000 | 86.0 |
pandas.set_option('max_colwidth', None)
n = 10
# Table of actual label, predicted class and original headline text for the
# first n rows (the fake headlines come first).
df = pandas.DataFrame(numpy.hstack((actual_classes[:n].reshape(-1, 1),
pred_classes[:n].reshape(-1, 1),
headlines_orig[:n])),
columns=('Actual', 'Predicted', 'Headline'))
df
Actual | Predicted | Headline | |
---|---|---|---|
0 | 0.0 | 0 | Donald Trump Sends Out Embarrassing New Year’s Eve Message; This is DisturbingAs U.S. budget fight looms, Republicans flip their fiscal script |
1 | 0.0 | 1 | Drunk Bragging Trump Staffer Started Russian Collusion InvestigationU.S. military to accept transgender recruits on Monday: Pentagon |
2 | 0.0 | 0 | Sheriff David Clarke Becomes An Internet Joke For Threatening To Poke People ‘In The Eye’Senior U.S. Republican senator: 'Let Mr. Mueller do his job' |
3 | 0.0 | 0 | Trump Is So Obsessed He Even Has Obama’s Name Coded Into His Website (IMAGES)FBI Russia probe helped by Australian diplomat tip-off: NYT |
4 | 0.0 | 0 | Pope Francis Just Called Out Donald Trump During His Christmas SpeechTrump wants Postal Service to charge 'much more' for Amazon shipments |
5 | 0.0 | 0 | Racist Alabama Cops Brutalize Black Boy While He Is In Handcuffs (GRAPHIC IMAGES)White House, Congress prepare for talks on spending, immigration |
6 | 0.0 | 0 | Fresh Off The Golf Course, Trump Lashes Out At FBI Deputy Director And James ComeyTrump says Russia probe will be fair, but timeline unclear: NYT |
7 | 0.0 | 0 | Trump Said Some INSANELY Racist Stuff Inside The Oval Office, And Witnesses Back It UpFactbox: Trump on Twitter (Dec 29) - Approval rating, Amazon |
8 | 0.0 | 0 | Former CIA Director Slams Trump Over UN Bullying, Openly Suggests He’s Acting Like A Dictator (TWEET)Trump on Twitter (Dec 28) - Global Warming |
9 | 0.0 | 0 | WATCH: Brand-New Pro-Trump Ad Features So Much A** Kissing It Will Make You SickAlabama official to certify Senator-elect Jones today despite challenge: CNN |
pandas.set_option('max_colwidth', None)
# Same table for the last n rows (the real headlines come last).
df = pandas.DataFrame(numpy.hstack((actual_classes[-n:].reshape(-1, 1),
pred_classes[-n:].reshape(-1, 1),
headlines_orig[-n:])),
columns=('Actual', 'Predicted', 'Headline'))
df
Actual | Predicted | Headline | |
---|---|---|---|
0 | 1.0 | 1 | Spicer Quickly Orders Press Briefing To Be OFF Camera After Trump’s Indefensible Morning MeltdownWhite House says Trump will announce Fed chair pick next week |
1 | 1.0 | 1 | Trump Throws Hissy Fit After New York Times Reports That He Doesn’t Know What’s In His Own Healthcare BillU.S. seeks meeting soon to revive Asia-Pacific 'Quad' security forum |
2 | 1.0 | 1 | Republican Senator Just Had A Priest Thrown In Jail For Protesting Trumpcare (VIDEO)Trump declares opioids a U.S. public health emergency |
3 | 1.0 | 1 | Trump Is Fighting Disney Over ‘Hall of Presidents’ Attraction Robot SpeechU.S. belatedly begins to comply with Russia sanctions law |
4 | 1.0 | 1 | Sally Yates: I Refused To Lie And Say Muslim Ban Wasn’t Based On Religion When We All Know It IsTwitter bans ads from two Russian media outlets, cites election meddling |
5 | 1.0 | 1 | GOP Rep. Wants A $30k A Year Housing Allowance; Twitter RIPS Him A New OneHouse panels seek documents on Puerto Rico utility deal |
6 | 1.0 | 1 | Brace Yourself For 74 Percent Higher Health Care Premiums Under New BillTrump releases some JFK files, blocks others under pressure |
7 | 1.0 | 1 | Grandma Will Have To Pay More Than $20k A Year For Insurance Under GOP BillTillerson tells Myanmar army chief U.S. concerned about reported atrocities |
8 | 1.0 | 0 | The Absolutely Cringeworthy Moment Trump Tried Flirting With An Irish Reporter (VIDEO)Exclusive: While advising Trump in 2016, ex-CIA chief proposed plan to discredit Turkish cleric |
9 | 1.0 | 1 | A Fed Up Reporter Just Stood Up To Sarah Huckabee As She Was Smearing The Free Press (VIDEO)Mattis visits Seoul for defense talks as tensions climb |
###### attention weights
# Recompute the per-head attention weights of one headline with the trained params.
headline_i = 1900
embedder_W = params[0]
X_embedding = jnp.take(embedder_W, X_tokens[headline_i:headline_i+1], axis=0) # 1 x max_n_tokens x embed_dim
X = X_embedding + position_encoding # [None, :, :]
attn, QKs = forward_attention(params, X)
words = headlines[headline_i]
# Headlines are right-padded with ' '. The first ' ' marks where the padding
# starts; a headline that fills every slot has none. Keep every real word.
n_words = words.index(' ') if ' ' in words else len(words)
words = words[:n_words]

# One heat map of attention weights per head (row token attends to column token).
plt.figure(figsize=(12, 12))
nplot = int(numpy.sqrt(n_heads)) + 1
for ploti, h in enumerate(range(n_heads), start=1):
    plt.subplot(nplot, nplot, ploti)
    plt.imshow(QKs[h][0, :n_words, :n_words])
    plt.colorbar()
plt.suptitle(' '.join(words))
plt.tight_layout()

## Draw attention weights as lines between words
plt.figure(figsize=(12, 12))
plt.suptitle(' '.join(words)) # headlines_orig[headline_i][0])
for ploti, h in enumerate(range(n_heads), start=1):
    plt.subplot(nplot, nplot, ploti)
    plt.xlim(0, 4)
    plt.ylim(0, n_words)
    plt.axis('off')
    for i, w in enumerate(words):
        plt.text(1, n_words - i, w, ha='right')
        plt.text(3, n_words - i, w, ha='left')
    # JAX arrays are immutable, so the slice needs no copy.
    QK = QKs[h][0, :n_words, :n_words]
    qk_max = numpy.max(QK)
    # Keep only links at least 70% as strong as this head's strongest one,
    # scaled so the strongest link is fully opaque (alpha must be in [0, 1]).
    QK = jnp.where(QK < 0.7 * qk_max, 0.0, QK) / qk_max
    for i in range(n_words):
        for j in range(n_words):
            plt.plot([1, 3], [n_words - i, n_words - j], 'r', alpha=QK[i, j].item())
plt.tight_layout()